#!/usr/bin/env python3
#
# Copyright 2025 Ettus Research, a National Instruments Brand
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
"""Captures samples into DRAM, then streams those to the host.

The example uses a single replay block to capture samples from multiple radios
and stores data in a file on the host.

Note: The --freq, --gain, and --antenna options can take a single value, which
is applied to all channels, or a list of values, which are applied to the
channels in order (cycling through the list if there are fewer values than
channels). Example:

    rfnoc_rx_replay_samples_to_file.py -f 1e9 -g 20 30 -c 0/Radio#0:0 0/Radio#0:1

This will use two channels from Radio 0 with a common frequency of 1 GHz.
The first channel will use a gain of 20 dB, and the second a gain of 30 dB.

A note on hardware capabilities: When downloading data from DRAM onto the host,
it is possible for the network interfaces to not be able to keep up with the
data rates generated by the USRP, and packets will be dropped.

If this happens, the following options may help:
- Make sure the system configuration is set up for highest performance, e.g.
  by adapting network driver settings (see also
  https://kb.ettus.com/USRP_Host_Performance_Tuning_Tips_and_Tricks#Increasing_Ring_Buffers)
- Fall back to a slower data rate, e.g., use 1 GbE instead of 10 or 100 GbE
- Reduce the streaming rate to the host with the --throttle option

Example usage:
rfnoc_rx_replay_samples_to_file.py --args addr=192.168.10.2,type=x300 --rate 10e6
                                   --output-file captured_samples.dat --numpy
                                   --freq 1e9 1.1e9 --gain 20 30 --antenna TX/RX RX2
                                   --radio-channels "0/Radio#0:0" "0/Radio#0:1"
                                   --block "0/Replay#0" --pkt-size 1472 --duration 5

Note: The argument --output-file specifies the file path to store received data
and needs to point to a location where the user has write access.
"""

import argparse
import os
import shutil
import sys
import tempfile
import time

import numpy as np
import uhd

try:
    import tqdm

    HAVE_TQDM = True
except ImportError:
    HAVE_TQDM = False

# Size in bytes of one sc16 (32-bit) sample on the USRP: a 16-bit I and a
# 16-bit Q component. Used throughout to convert between sample counts and
# replay-memory / packet byte counts.
BYTES_PER_SAMP = 4


# pylint: disable=too-many-arguments
def parse_args():
    """Return parsed command line args.

    Exits via ``parser.error()`` (SystemExit with status 2) when --throttle is
    outside the range (0, 1] or --duration is not positive, so invalid values
    are reported up front instead of failing later inside UHD or numpy.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--args",
        "-a",
        type=str,
        default="",
        help="""specifies the USRP device arguments, which holds
        multiple key value pairs separated by commas
        (e.g., addr=192.168.40.2,type=x300) [default = ""].""",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        required=True,
        help="specifies the file path to store the captured samples [input is required].",
    )
    parser.add_argument(
        "-r",
        "--rate",
        type=float,
        help="specifies the sample rate in samples/sec [default = MCR].",
    )
    parser.add_argument(
        "-f",
        "--freq",
        type=float,
        required=True,
        nargs="+",
        help="specifies the center frequency in Hz for each channel [input is required].",
    )
    # Gains are floating point in UHD, so accept fractional dB values (e.g. 20.5).
    parser.add_argument(
        "-g",
        "--gain",
        type=float,
        default=[10],
        nargs="+",
        help="specifies the receive gain in dB for each channel [default = 10].",
    )
    parser.add_argument(
        "--antenna",
        type=str,
        nargs="+",
        help="specifies the antenna to use for each channel "
        "(e.g., TX/RX TX2) [default = None].",
    )
    parser.add_argument(
        "-d",
        "--duration",
        type=float,
        help="specifies the duration of the capture in seconds "
        "[default = sets the duration such that DRAM completely fills].",
    )
    parser.add_argument(
        "--delay",
        "-l",
        type=float,
        default=0.5,
        help="specifies the delay in seconds between issuing a capture command "
        "on the host and starting the capture [default = 0.5].",
    )
    parser.add_argument(
        "-c",
        "--radio-channels",
        default=["0/Radio#0:0"],
        nargs="+",
        help='specifies the list of radios plus their channels [default = "0/Radio#0:0"].',
    )
    parser.add_argument(
        "--block",
        "-b",
        type=str,
        default="0/Replay#0",
        help='specifies the Replay block to test [default = "0/Replay#0"].',
    )
    parser.add_argument(
        "--pkt-size",
        "-k",
        type=int,
        default=None,
        help="specifies the CHDR packet size in bytes to stream data to the host "
        "[default = maximum CHDR packet size for this transport layer].",
    )
    parser.add_argument(
        "-n",
        "--numpy",
        default=False,
        action="store_true",
        help="specifies whether to save the output file in NumPy format. If --numpy "
        "is not specified, the output file will be saved in binary format.",
    )
    parser.add_argument(
        "--cpu-format",
        default="sc16",
        choices=["sc16", "fc32"],
        help="specifies the cpu format for storing data [default = sc16].",
    )
    parser.add_argument(
        "--throttle",
        type=float,
        default=0.25,
        help="specifies the throttle for streaming to host in the range of (0, 1]. "
        "E.g., use 1 for maximum rate, use 0.5 for half the maximum rate, etc. "
        "Note that other factors may affect the actual rate, such as the rate of "
        "the source or the speed supported by the transport [default = 0.25].",
    )
    parser.add_argument(
        "--no-progress",
        "-N",
        action="store_true",
        help="specifies whether to display progress bar or not during data transfer.",
    )
    args = parser.parse_args()
    # Range checks argparse cannot express directly.
    if not 0 < args.throttle <= 1:
        parser.error(f"--throttle must be in the range (0, 1], got {args.throttle}")
    if args.duration is not None and args.duration <= 0:
        parser.error(f"--duration must be positive, got {args.duration}")
    return args


def enumerate_radios(graph, radio_chans):
    """Return a list of (RadioControl, channel) pairs to use for this capture.

    Each entry of ``radio_chans`` is either a radio block ID (e.g.
    ``"0/Radio#0"``, channel 0 implied) or a block ID, a colon, and a channel
    index (e.g. ``"0/Radio#0:1"``).

    Args:
        graph: The RFNoC graph to look the radio blocks up in.
        radio_chans: List of radio/channel spec strings as described above.

    Returns:
        List of ``(uhd.rfnoc.RadioControl, int)`` tuples, in input order.

    Raises:
        ValueError: if a channel part is not a non-negative integer (this also
            rejects specs with more than one colon).
        RuntimeError: if a block ID is not one of the graph's radio blocks.
    """
    # Parse all specs before touching the graph so malformed input is
    # reported with a clear message.
    radio_id_chan_pairs = []
    for spec in radio_chans:
        block_id, sep, chan_str = spec.partition(":")
        if not sep:
            radio_id_chan_pairs.append((block_id, 0))
            continue
        try:
            chan = int(chan_str)
        except ValueError as err:
            raise ValueError(
                f"Invalid radio channel spec '{spec}': expected <block ID>:<channel>"
            ) from err
        if chan < 0:
            raise ValueError(f"Invalid radio channel spec '{spec}': channel must be >= 0")
        radio_id_chan_pairs.append((block_id, chan))
    # Sanity checks
    available_radios = graph.find_blocks("Radio")
    radio_chan_pairs = []
    for block_id, chan in radio_id_chan_pairs:
        if block_id not in available_radios:
            raise RuntimeError(f"'{block_id}' is not a valid radio block ID!")
        radio_chan_pairs.append((uhd.rfnoc.RadioControl(graph.get_block(block_id)), chan))
    return radio_chan_pairs


def connect_radios(graph, replay, radio_chan_pairs, freqs, gains, antennas, rate):
    """Set up the radio -> replay part of the graph and configure the radios.

    Radio channel ``i`` in ``radio_chan_pairs`` is connected to replay input
    port ``i``. Frequencies, gains and antennas are applied cyclically
    (``values[i % len(values)]``), so a single value is shared by all channels.

    Args:
        graph: The RFNoC graph.
        replay: Replay block controller that receives the radio data.
        radio_chan_pairs: List of (RadioControl, channel) tuples.
        freqs: List of RX center frequencies in Hz.
        gains: List of RX gains in dB.
        antennas: List of antenna names, or None to leave antennas unchanged.
        rate: Requested sample rate in Hz, or None to use the first radio's
            current rate.

    Returns:
        The actual sample rate, as returned by the DDC/radio rate setter,
        which must be the same for all channels.

    Raises:
        RuntimeError: if the channels end up with different sample rates.
    """
    if rate is None:
        rate = radio_chan_pairs[0][0].get_rate()
    print(f"Requested rate: {rate/1e6:.2f} Msps")
    replay_id = replay.get_unique_id()
    actual_rate = None
    for replay_port_idx, (radio, chan) in enumerate(radio_chan_pairs):
        radio_id = radio.get_unique_id()
        print(f"Connecting {radio_id}:{chan} to {replay_id}:{replay_port_idx}")
        radio.set_rx_frequency(freqs[replay_port_idx % len(freqs)], chan)
        radio.set_rx_gain(gains[replay_port_idx % len(gains)], chan)
        if antennas is not None:
            radio.set_rx_antenna(antennas[replay_port_idx % len(antennas)], chan)
        print(
            f"--> Radio settings: fc={radio.get_rx_frequency(chan)/1e6:.2f} MHz, "
            f" gain={radio.get_rx_gain(chan)} dB, "
            f"antenna={radio.get_rx_antenna(chan)}"
        )
        radio_to_replay_graph = uhd.rfnoc.connect_through_blocks(
            graph, radio_id, chan, replay_id, replay_port_idx
        )
        # Look for a DDC on the path from the radio to the replay block; if
        # there is one, it performs the rate conversion instead of the radio.
        ddc_block = next(
            (
                (x.dst_blockid, x.dst_port)
                for x in radio_to_replay_graph
                if uhd.rfnoc.BlockID(x.dst_blockid).get_block_name() == "DDC"
            ),
            None,
        )
        if ddc_block is not None:
            ddc_id, ddc_port = ddc_block
            print(f"Found DDC block on channel {chan}.")
            # Configure the DDC channel this radio is actually wired into,
            # which need not have the same index as the radio channel.
            this_rate = uhd.rfnoc.DdcBlockControl(graph.get_block(ddc_id)).set_output_rate(
                rate, ddc_port
            )
        else:
            this_rate = radio.set_rate(rate)
        if actual_rate is None:
            actual_rate = this_rate
        elif actual_rate != this_rate:
            raise RuntimeError(
                f"Unexpected rate mismatch: {this_rate/1e6:.3f} Msps on "
                f"{radio_id}:{chan}, expected {actual_rate/1e6:.3f} Msps."
            )
    return actual_rate


def _sanitize_args(replay, num_ports, num_bytes, pkt_size_bytes=None):
    """Sanitize requested args based on the capabilities of the replay block.

    The replay memory is split evenly into ``num_ports`` regions of
    ``mem_stride`` bytes; port ``i`` uses the region starting at
    ``i * mem_stride``. The play and record item types of every port in use
    are set to sc16, and optionally its maximum packet size.

    Args:
        replay: Replay block controller.
        num_ports: Number of replay ports (one per radio channel) in use.
        num_bytes: Requested bytes to record per port, or None to fill each
            port's share of the memory.
        pkt_size_bytes: Optional CHDR packet size in bytes used for playback.

    Returns:
        Tuple ``(mem_stride, num_bytes)``: bytes of memory reserved per port,
        and bytes to record per port, clamped to ``mem_stride`` and rounded
        down to a whole number of samples.

    Raises:
        RuntimeError: if more ports are requested than the replay block has.
        ValueError: if the resulting record size is less than one sample.
    """
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if num_ports > replay.get_num_input_ports():
        raise RuntimeError(
            f"Requested {num_ports} ports, but {replay.get_unique_id()} only has "
            f"{replay.get_num_input_ports()} input ports."
        )
    ## Figure out how many bytes to send
    mem_size = replay.get_mem_size()
    mem_stride = mem_size // num_ports
    num_bytes = int(num_bytes) if num_bytes is not None else mem_stride
    print(f"Total memory size: {mem_size // 1024 // 1024} MiB")
    # Set the number of bytes to test
    print(f"Requested Record size per port: {num_bytes // 1024 // 1024} MiB")
    if num_bytes > mem_stride:
        num_bytes = mem_stride
        print(
            f"WARNING: Exceeds allocated space per port! "
            f"Reducing to {num_bytes // 1024 // 1024} MiB"
        )
    # Only whole samples are recorded and downloaded.
    num_bytes -= num_bytes % BYTES_PER_SAMP
    if num_bytes <= 0:
        raise ValueError("Record size per port must be at least one sample.")
    for port in range(num_ports):
        print(
            f"Port {port} address space: "
            f"0x{mem_stride*port:08X} - 0x{mem_stride*port+num_bytes:08X}"
        )
        # Configure the item types on every port in use, not only port 0.
        replay.set_play_type("sc16", port)
        replay.set_record_type("sc16", port)
        if pkt_size_bytes is not None:
            replay.set_max_items_per_packet(pkt_size_bytes // BYTES_PER_SAMP, port)
    return mem_stride, num_bytes


def run_capture(graph, replay, radio_chan_pairs, mem_stride, num_samps, rate, cap_delay):
    """Record from radio into DRAM.

    Replay port ``i`` is armed to record ``num_samps`` samples into the memory
    region starting at ``i * mem_stride``. A timed stream command is then
    issued to every radio, and the record fullness is polled until every
    port's buffer is full.

    Args:
        graph: The RFNoC graph (used to read the current device time).
        replay: Replay block controller to record into.
        radio_chan_pairs: List of (RadioControl, channel) tuples; entry ``i``
            feeds replay port ``i``.
        mem_stride: Bytes of replay memory reserved per port.
        num_samps: Samples to record per port, at rate ``rate``.
        rate: Sample rate in Hz of the data arriving at the replay block.
        cap_delay: Delay in seconds from now until streaming starts.

    Returns:
        Number of samples recorded per port.

    Raises:
        RuntimeError: if the buffers do not fill up before the timeout.
    """
    num_bytes = num_samps * BYTES_PER_SAMP
    num_ports = len(radio_chan_pairs)
    ## Arm replay block for recording
    for port in range(num_ports):
        replay.record(port * mem_stride, num_bytes, port)
    ## Send stream command to all radios
    # This 'rate ratio' would be better handled by RFNoC. If the replay block
    # were to submit the stream command to the radio, this would not be necessary.
    # The stream command goes to the radio, which may run faster than `rate`
    # when a DDC decimates downstream. round() is used because the float
    # quotient of the two rates can land just below the integer decimation
    # factor, which int() would truncate to one less.
    rate_ratio = max(1, round(radio_chan_pairs[0][0].get_rate() / rate))
    stream_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.num_done)
    stream_cmd.num_samps = num_samps * rate_ratio
    stream_cmd.stream_now = False
    time_now = graph.get_mb_controller().get_timekeeper(0).get_time_now()
    stream_cmd.time_spec = time_now + uhd.types.TimeSpec(cap_delay)
    print(f"Requesting {num_samps} samples from {num_ports} radio(s)...")
    print(f"Capture will take approx. {num_samps/rate:.1f} seconds...")
    for radio, chan in radio_chan_pairs:
        radio.issue_stream_cmd(stream_cmd, chan)
    ## Wait for record buffers to fill up
    timeout = time.monotonic() + num_samps / rate + cap_delay + 2.0
    if HAVE_TQDM:
        with tqdm.tqdm(total=num_bytes * num_ports, unit_scale=True, unit="byte") as pbar:
            total_bytes_recorded = 0
            while total_bytes_recorded < num_ports * num_bytes:
                bytes_recorded = sum(replay.get_record_fullness(port) for port in range(num_ports))
                pbar.update(bytes_recorded - total_bytes_recorded)
                total_bytes_recorded = bytes_recorded
                time.sleep(0.100)
                if time.monotonic() > timeout:
                    raise RuntimeError("Timeout while loading replay buffer!")
    else:
        while any(replay.get_record_fullness(port) < num_bytes for port in range(num_ports)):
            time.sleep(0.200)
            if time.monotonic() > timeout:
                raise RuntimeError("Timeout while loading replay buffer!")
    return num_bytes // BYTES_PER_SAMP


def rx_data_to_host(replay, rx_streamer, output_data, mem_stride, num_samps, pkt_size_bytes):
    """Download data from previously configured replay block via streamer object.

    Replay port ``i`` is configured to play back ``num_samps`` samples starting
    at ``i * mem_stride``, and the samples are received through
    ``rx_streamer`` into ``output_data``, one row per channel.

    Args:
        replay: Replay block controller holding the recorded data.
        rx_streamer: RX streamer connected to the replay output ports.
        output_data: Array indexed as ``[channel, sample]`` with room for
            ``num_samps`` samples per channel.
        mem_stride: Bytes of replay memory per port, as used when recording.
            If None, the memory is assumed to be split evenly across the
            streamer's channels.
        num_samps: Number of samples to download per channel.
        pkt_size_bytes: Packet size in bytes; sizes the receive buffer used
            when the progress bar is enabled.

    Returns:
        Number of samples received per channel.
    """
    print("Downloading data to host...")
    rx_md = uhd.types.RXMetadata()
    err_codes = uhd.types.RXMetadataErrorCode
    num_ports = rx_streamer.get_num_channels()
    num_bytes = num_samps * BYTES_PER_SAMP
    max_samps_per_pkt = pkt_size_bytes // BYTES_PER_SAMP
    # Play back from the same regions the caller recorded into.
    if mem_stride is None:
        mem_stride = replay.get_mem_size() // num_ports
    # Configure playback regions
    for port in range(num_ports):
        replay.config_play(port * mem_stride, num_bytes, port)
    stream_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.num_done)
    stream_cmd.num_samps = num_samps
    # This is not strictly necessary, but the streamer will not allow a
    # multi-chan operation without a time spec.
    stream_cmd.stream_now = False
    stream_cmd.time_spec = uhd.types.TimeSpec(0.0)
    rx_streamer.issue_stream_cmd(stream_cmd)
    if HAVE_TQDM:
        num_rx = 0
        # Receive in chunks of up to 1000 packets so the progress bar can advance.
        output_buf = np.zeros((num_ports, 1000 * max_samps_per_pkt), dtype=output_data.dtype)
        with tqdm.tqdm(total=num_bytes * num_ports, unit_scale=True, unit="byte") as pbar:
            while num_rx < num_samps:
                num_rx_i = rx_streamer.recv(output_buf, rx_md, 1.0)
                if rx_md.error_code == err_codes.timeout:
                    print("recv() timed out. Exiting...")
                    break
                if rx_md.error_code == err_codes.overflow and rx_md.out_of_sequence:
                    print("Detected sequence error!")
                elif rx_md.error_code == err_codes.overflow:
                    print("ERROR: Overflow detected!")
                elif rx_md.error_code != err_codes.none:
                    print("ERROR: recv() gave unexpected error code: " + rx_md.strerror())
                pbar.update(num_rx_i * BYTES_PER_SAMP * num_ports)
                output_data[:, num_rx : num_rx + num_rx_i] = output_buf[:, 0:num_rx_i]
                num_rx += num_rx_i
    else:
        num_rx = rx_streamer.recv(output_data, rx_md, 5.0)
        if rx_md.error_code != err_codes.none:
            print("Error during download: " + rx_md.strerror())
    if num_rx != num_samps:
        print("ERROR: Fewer samples received than expected!")
    print("Download complete.")
    return num_rx


def main():
    """Run capture: record into replay DRAM, then download to the output file.

    Returns:
        True if every requested sample was downloaded, False otherwise (the
        ``__main__`` guard turns False into a non-zero exit status).
    """
    global HAVE_TQDM
    args = parse_args()
    if args.no_progress:
        HAVE_TQDM = False
    graph = uhd.rfnoc.RfnocGraph(args.args)
    replay = uhd.rfnoc.ReplayBlockControl(graph.get_block(args.block))
    radio_chan_pairs = enumerate_radios(graph, args.radio_channels)
    rate = connect_radios(
        graph, replay, radio_chan_pairs, args.freq, args.gain, args.antenna, args.rate
    )
    print(f"Using rate: {rate/1e6:.3f} Msps")
    # Set up streamer
    stream_args = uhd.usrp.StreamArgs(args.cpu_format, "sc16")
    stream_args.args["throttle"] = str(args.throttle)
    rx_streamer = graph.create_rx_streamer(len(radio_chan_pairs), stream_args)
    num_ports = rx_streamer.get_num_channels()
    for chan in range(len(radio_chan_pairs)):
        # This won't work if we can't directly attach the streamer to the
        # replay block.
        graph.connect(replay.get_unique_id(), chan, rx_streamer, chan)
    graph.commit()

    num_bytes = args.duration * rate * BYTES_PER_SAMP if args.duration is not None else None
    mem_stride, num_bytes = _sanitize_args(
        replay,
        len(radio_chan_pairs),
        num_bytes,
        pkt_size_bytes=args.pkt_size,
    )
    num_samps = run_capture(
        graph, replay, radio_chan_pairs, mem_stride, num_bytes // BYTES_PER_SAMP, rate, args.delay
    )

    # With --numpy, download into a temporary memory-mapped file and convert it
    # to .npy afterwards; otherwise write the raw samples straight to the file.
    tmp_dir = tempfile.mkdtemp() if args.numpy else None
    if tmp_dir is not None:
        mmap_filename = os.path.join(tmp_dir, "replay_capture.dat")
    else:
        mmap_filename = args.output_file
    cap_dtype = np.complex64 if args.cpu_format == "fc32" else np.uint32

    output_data = None
    try:
        output_data = np.memmap(
            mmap_filename, shape=(num_ports, num_samps), mode="w+", dtype=cap_dtype
        )
        num_rx = rx_data_to_host(
            replay, rx_streamer, output_data, mem_stride, num_samps, replay.get_max_packet_size(0)
        )
        if args.numpy:
            print(f"Saving data as Numpy array to {args.output_file}...")
            with open(args.output_file, "wb") as out_file:
                np.save(out_file, output_data)
        else:
            # Push the memory-mapped samples out to the output file.
            output_data.flush()
    finally:
        # Release our reference to the memory map before deleting its backing
        # file, and remove the temporary directory even if the download failed.
        output_data = None
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir)
    return num_rx == num_samps


if __name__ == "__main__":
    # main() reports success as True; map that to exit status 0, failure to 1.
    sys.exit(0 if main() else 1)
